package com.bosssoft.platform.fasttcc.log.file;

import com.bosssoft.platform.fasttcc.*;
import com.bosssoft.platform.fasttcc.impl.XidImpl;
import com.bosssoft.platform.fasttcc.rpc.command.TccCommandHandler;
import com.jfireframework.baseutil.StringUtil;
import com.jfireframework.baseutil.TRACEID;
import com.jfireframework.baseutil.reflect.ReflectUtil;
import com.jfireframework.licp.Licp;
import com.jfireframework.licp.buf.ByteBuf;
import com.jfireframework.licp.buf.HeapByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * {@link TccLogger} that appends binary TCC log records to {@code <baseDir>/TccTransaction.log}.
 * <p>
 * Each record is encoded into a shared staging {@link ByteBuffer} and written to the file channel
 * while {@link #lock} is held. After more than {@link #threshold} bytes have been written and the
 * {@code ArchiveWorker} reports itself idle, the active file is renamed to {@code Archiving.log},
 * a new {@code TccTransaction.log} is opened and the worker is started.
 * <p>
 * {@link #recover(TccTransactionManager, TccCommandHandler)} rebuilds the unfinished transactions
 * returned by {@code ArchiveWorker.doArchive()} and resumes or completes them, consulting the
 * {@code fasttcc} table for the commit state of local transactions and complete stages.
 */
public class FileTccLogger implements TccLogger
{
    /**
     * Directory that holds the transaction log files.
     */
    private final        String                    baseDir;
    private final        DataSource                dataSource;
    private final        ArchiveWorker             archiveWorker;
    /**
     * Held for every record write; guards the log file, the staging buffers and {@link #writeCount}.
     */
    private final        Lock                      lock                  = new ReentrantLock();
    private              File                      logFile;
    private              RandomAccessFile          randomAccessFile;
    private              FileChannel               transactionLog;
    /**
     * Staging buffer for a single record; swapped for a larger heap buffer when a record does not fit.
     */
    private              ByteBuffer                buffer                = ByteBuffer.allocateDirect(1024 * 1024);
    /**
     * Bytes written to the active log since this logger opened it or last rolled it over.
     */
    private              int                       writeCount            = 0;
    /**
     * Written-byte count above which the active log is handed to the archive worker (4 MiB).
     */
    private final        int                       threshold             = 1024 * 1024 * 4;
    private static final Charset                   charset               = Charset.forName("utf8");
    /**
     * Cache of encoded "className#methodName" signatures, keyed by operation identity.
     */
    private final        Map<TccOperation, byte[]> methodSignMap         = new IdentityHashMap<>();
    private final        Licp                      licp                  = new Licp();
    private final        ByteBuf                   paramSerializationBuf = HeapByteBuf.allocate(1024 * 1024);
    /**
     * Whether the active log file has been opened successfully.
     */
    private              boolean                   init                  = false;
    private static final Logger                    LOGGER                = LoggerFactory.getLogger(FileTccLogger.class);

    /**
     * @param baseDir    directory for the log files; created if it does not exist
     * @param dataSource data source queried during recovery and passed to the archive worker
     */
    public FileTccLogger(String baseDir, DataSource dataSource)
    {
        this.baseDir = baseDir;
        this.dataSource = dataSource;
        archiveWorker = new ArchiveWorker(baseDir, dataSource);
        File dir = new File(baseDir);
        if (dir.exists() == false)
        {
            dir.mkdirs();
        }
    }

    /**
     * Opens {@code TccTransaction.log} for appending on first use, creating it when missing.
     * The logger is only marked initialised once the file is open, so a failed attempt is retried
     * by the next write instead of leaving the channel {@code null}.
     */
    private void initIfNecessary()
    {
        if (init)
        {
            return;
        }
        logFile = new File(baseDir + File.separator + "TccTransaction.log");
        String traceId = TRACEID.currentTraceId();
        try
        {
            if (logFile.exists() == false)
            {
                logFile.createNewFile();
                LOGGER.debug("traceId:{} 创建事务日志文件:{}", traceId, logFile.getAbsolutePath());
            }
            openForAppend(logFile);
            init = true;
        }
        catch (IOException e)
        {
            // Log before rethrowing; ReflectUtil.throwException is expected to throw, so code after it would not run.
            LOGGER.error("traceId:{} 出现未知异常", traceId, e);
            ReflectUtil.throwException(e);
        }
    }

    /**
     * Opens {@code file} read-write, positions it at its end and installs it as the active log.
     * The file handle is closed again if positioning fails.
     *
     * @throws IOException if the file cannot be opened or positioned
     */
    private void openForAppend(File file) throws IOException
    {
        RandomAccessFile raf = new RandomAccessFile(file, "rw");
        try
        {
            raf.seek(raf.length());
        }
        catch (IOException e)
        {
            try
            {
                raf.close();
            }
            catch (IOException closeFailure)
            {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        randomAccessFile = raf;
        transactionLog = raf.getChannel();
    }

    /**
     * Logs a CREATE_TCC_TRANSACTION record: schema, global id, role and the length-prefixed
     * UTF-8 propagator identifier (length 0 and no payload when the identifier is blank).
     */
    @Override
    public void createTccTransaction(final TccTransaction transaction)
    {
        doLogTemplate(new BufferWriter()
        {
            @Override
            public void write()
            {
                buffer.put(LogRecordSchema.CREATE_TCC_TRANSACTION.schema());
                buffer.put(transaction.getXid().getGlobalId());
                buffer.put(transaction.role());
                String propagatedBy = transaction.getPropagatedBy();
                byte[] value        = StringUtil.isNotBlank(propagatedBy) ? propagatedBy.getBytes(charset) : null;
                short  len          = value == null ? 0 : (short) value.length;
                buffer.putShort(len);
                if (len != 0)
                {
                    buffer.put(value);
                }
                if (LOGGER.isDebugEnabled())
                {
                    // Guarded so the hex conversion is skipped when debug logging is off.
                    LOGGER.debug("traceId:{} 新建TCC事务,xid:{},传播者标识:{}", TRACEID.currentTraceId(), StringUtil.toHexString(transaction.getXid().getGlobalId()), propagatedBy);
                }
            }
        });
    }

    /**
     * Encodes one log record into {@link #buffer}; invoked by {@link #doLogTemplate} with the lock held.
     */
    interface BufferWriter
    {
        void write();
    }

    /**
     * Encodes a record with {@code writer}, writes it to the active log and updates the write count,
     * all under {@link #lock}. The lock is released and the staging buffer reset even when opening
     * the file, encoding or writing fails, so a partial record never leaks into the next one.
     */
    private void doLogTemplate(BufferWriter writer)
    {
        lock.lock();
        try
        {
            initIfNecessary();
            writer.write();
            int position = buffer.position();
            writeToFile();
            addWriteCount(position);
        }
        finally
        {
            buffer.clear();
            lock.unlock();
        }
    }

    /**
     * Adds {@code count} to {@link #writeCount}. Once it exceeds {@link #threshold} and the archive
     * worker is idle, renames the active log to {@code Archiving.log}, opens a fresh
     * {@code TccTransaction.log}, resets the count and starts the worker.
     * If the rename fails, the original file is reopened for appending and the count is left
     * above the threshold, so the next write retries the rollover.
     */
    private void addWriteCount(int count)
    {
        writeCount += count;
        if (writeCount <= threshold)
        {
            return;
        }
        if (archiveWorker.isIdle() == false)
        {
            LOGGER.debug("事务日志写入量:{},已经超过阀值:{},归档者线程在工作中，等待下次触发", writeCount, threshold);
            return;
        }
        LOGGER.debug("事务日志写入量:{},已经超过阀值:{},启动归档者线程", writeCount, threshold);
        try
        {
            transactionLog.close();
            randomAccessFile.close();
            File archiving = new File(baseDir + File.separator + "Archiving.log");
            if (logFile.renameTo(archiving) == false)
            {
                // Reopen at the end of the existing file; reopening at offset 0 would overwrite its records.
                LOGGER.warn("traceId:{} failed to rename {} to {}, keep appending to the current log", TRACEID.currentTraceId(), logFile.getAbsolutePath(), archiving.getAbsolutePath());
                openForAppend(logFile);
                return;
            }
            logFile = new File(baseDir + File.separator + "TccTransaction.log");
            logFile.createNewFile();
            openForAppend(logFile);
            writeCount = 0;
            archiveWorker.start();
        }
        catch (IOException e)
        {
            ReflectUtil.throwException(e);
        }
    }

    /**
     * Flips {@link #buffer}, writes all of its content to the active log channel and clears it.
     */
    private void writeToFile()
    {
        try
        {
            buffer.flip();
            while (buffer.hasRemaining())
            {
                transactionLog.write(buffer);
            }
            buffer.clear();
            // NOTE(review): transactionLog.force(false) is not called, so records are not forced to the
            // storage device here; confirm that losing them on an OS crash / power failure is acceptable.
        }
        catch (IOException e)
        {
            ReflectUtil.throwException(e);
        }
    }

    /**
     * Logs a REGISTER_lOCAL_TRANSACTION record: schema, global id, branch id and the previous /
     * own transaction indexes (each truncated to one byte).
     */
    @Override
    public void registerLocalTransaction(final LocalTransaction localTransaction)
    {
        doLogTemplate(new BufferWriter()
        {
            @Override
            public void write()
            {
                buffer.put(LogRecordSchema.REGISTER_lOCAL_TRANSACTION.schema());
                buffer.put(localTransaction.xid().getGlobalId());
                buffer.put(localTransaction.xid().getBranchId());
                buffer.put((byte) localTransaction.preLocalTxIndex());
                buffer.put((byte) localTransaction.txIndex());
            }
        });
    }

    /**
     * Logs a REGISTER_TCC_INVOKE record: schema, global id, length-prefixed method signature,
     * associated local transaction index, complete-stage branch id and the length-prefixed
     * Licp-serialized invocation parameters.
     */
    @Override
    public void registerTccInvoke(final TccInvoke tccInvoke, final TccTransaction tccTransaction)
    {
        doLogTemplate(new BufferWriter()
        {
            @Override
            public void write()
            {
                buffer.put(LogRecordSchema.REGISTER_TCC_INVOKE.schema());
                buffer.put(tccTransaction.getXid().getGlobalId());
                byte[] methodSign = getMethodSign(tccInvoke);
                buffer.putShort((short) methodSign.length);
                buffer.put(methodSign);
                buffer.put((byte) tccInvoke.getAssociatedLocalTransaction().txIndex());
                buffer.put(tccInvoke.getCompleteStageXid().getBranchId());
                byte[] content;
                try
                {
                    licp.serialize(tccInvoke.getParams(), paramSerializationBuf);
                    content = paramSerializationBuf.toArray();
                }
                finally
                {
                    // Always reset so a failed serialization cannot leak bytes into the next record.
                    paramSerializationBuf.clear();
                }
                // The parameters may not fit the staging buffer: reserve the 4-byte length prefix plus payload.
                ensureRemaining(4 + content.length);
                buffer.putInt(content.length);
                buffer.put(content);
            }
        });
    }

    /**
     * Makes sure {@link #buffer} can take {@code required} more bytes, copying its pending content
     * into a larger heap buffer when it cannot.
     */
    private void ensureRemaining(int required)
    {
        if (buffer.remaining() >= required)
        {
            return;
        }
        ByteBuffer larger = ByteBuffer.allocate(buffer.capacity() + required);
        buffer.flip();
        larger.put(buffer);
        buffer = larger;
    }

    /**
     * Returns the UTF-8 bytes of {@code tryClassName#tryMethodName} for the invoke's operation,
     * computing them once per operation instance.
     */
    private byte[] getMethodSign(TccInvoke tccInvoke)
    {
        TccOperation tccOperation = tccInvoke.getTccOperation();
        byte[]       sign         = methodSignMap.get(tccOperation);
        if (sign == null)
        {
            sign = (tccOperation.getTryClass().getName() + "#" + tccOperation.getTryMethod().getName()).getBytes(charset);
            methodSignMap.put(tccOperation, sign);
        }
        return sign;
    }

    /**
     * Logs a REGISTER_REMOTE_RESOURCE record: schema, global id and the length-prefixed UTF-8
     * resource identifier.
     */
    @Override
    public void registerRemoteResource(final RemoteResource remoteResource, final Xid xid)
    {
        doLogTemplate(new BufferWriter()
        {
            @Override
            public void write()
            {
                buffer.put(LogRecordSchema.REGISTER_REMOTE_RESOURCE.schema());
                buffer.put(xid.getGlobalId());
                byte[] identifierBytes = remoteResource.getIdentifier().getBytes(charset);
                buffer.putShort((short) identifierBytes.length);
                buffer.put(identifierBytes);
            }
        });
    }

    /**
     * Logs an UPDATE_TRANSACTION_STATE record: schema, global id and the state truncated to one byte.
     */
    @Override
    public void updateTccTransactionState(final TccTransaction transaction)
    {
        doLogTemplate(new BufferWriter()
        {
            @Override
            public void write()
            {
                buffer.put(LogRecordSchema.UPDATE_TRANSACTION_STATE.schema());
                buffer.put(transaction.getXid().getGlobalId());
                buffer.put((byte) transaction.getState());
            }
        });
    }

    /**
     * Crash recovery: finishes an interrupted archive, archives every remaining log file and then,
     * for each unfinished transaction, either completes it or hands it to {@code tccCommandHandler}.
     *
     * @throws IllegalStateException if a log file cannot be deleted or renamed
     */
    @Override
    public void recover(TccTransactionManager tccTransactionManager, TccCommandHandler tccCommandHandler)
    {
        String traceId = TRACEID.currentTraceId();
        promoteFinishedArchiveTmp(traceId);
        Collection<ArchiveWorker.TccTransactionInfo> unfinished = archiveAllLogs(traceId);
        if (unfinished.isEmpty())
        {
            LOGGER.debug("traceId:{} 当前不存在未完成的TCC事务，宕机恢复结束", traceId);
            return;
        }
        for (ArchiveWorker.TccTransactionInfo each : unfinished)
        {
            TccTransaction tccTransaction = recoverTccTransaction(dataSource, tccTransactionManager, each);
            if (tccTransaction.getState() == TccTransaction.ACTIVE)
            {
                if (tccTransaction.role() == TccTransaction.COORDINATOR)
                {
                    // Still ACTIVE: as coordinator, decide now — commit iff the first local transaction committed.
                    if (tccTransaction.isFirstLocalTransactionCommited())
                    {
                        tccTransaction.markForCommit();
                    }
                    else
                    {
                        tccTransaction.markForRollback();
                    }
                    tccTransaction.processCompleteStage();
                }
                else
                {
                    tccCommandHandler.addTransaction(tccTransaction);
                }
            }
            else
            {
                tccTransaction.processCompleteStage();
            }
        }
    }

    /**
     * When no {@code Archiving.log} exists but {@code Archived.log.tmp} does, replaces
     * {@code Archived.log} with the tmp file. Deletion happens immediately: {@code deleteOnExit()}
     * would instead remove the freshly promoted {@code Archived.log} at JVM shutdown.
     */
    private void promoteFinishedArchiveTmp(String traceId)
    {
        if (new File(baseDir + File.separator + "Archiving.log").exists())
        {
            return;
        }
        File tmpFile = new File(baseDir + File.separator + "Archived.log.tmp");
        if (tmpFile.exists() == false)
        {
            return;
        }
        LOGGER.debug("traceId:{} 存在Archived.log.tmp文件。首先删除同目录下Archived.log文件", traceId);
        File archived = new File(baseDir + File.separator + "Archived.log");
        if (archived.exists() && archived.delete() == false)
        {
            throw new IllegalStateException("Unable to delete " + archived.getAbsolutePath());
        }
        if (tmpFile.renameTo(archived) == false)
        {
            throw new IllegalStateException("Unable to rename " + tmpFile.getAbsolutePath() + " to " + archived.getAbsolutePath());
        }
    }

    /**
     * Repeatedly deletes any stale {@code Archived.log.tmp} and runs {@code archiveWorker.doArchive()},
     * renaming {@code TccTransaction.log} to {@code Archiving.log} between rounds, until no
     * {@code TccTransaction.log} is left.
     *
     * @return the unfinished transactions reported by the last archive round
     */
    private Collection<ArchiveWorker.TccTransactionInfo> archiveAllLogs(String traceId)
    {
        File tmpFile   = new File(baseDir + File.separator + "Archived.log.tmp");
        File activeLog = new File(baseDir + File.separator + "TccTransaction.log");
        File archiving = new File(baseDir + File.separator + "Archiving.log");
        while (true)
        {
            if (tmpFile.exists() && tmpFile.delete() == false)
            {
                throw new IllegalStateException("Unable to delete " + tmpFile.getAbsolutePath());
            }
            Collection<ArchiveWorker.TccTransactionInfo> unfinished = archiveWorker.doArchive();
            LOGGER.debug("traceId:{} 宕机恢复，仍需要执行的事务有：{}个", traceId, unfinished.size());
            if (activeLog.exists() == false)
            {
                return unfinished;
            }
            if (activeLog.renameTo(archiving) == false)
            {
                throw new IllegalStateException("Unable to rename " + activeLog.getAbsolutePath() + " to " + archiving.getAbsolutePath());
            }
            LOGGER.debug("traceId:{} 存在事务日志文件，重名为Archiving.log，继续归档流程", traceId);
        }
    }

    /**
     * Batch-deletes the rows of the given transactions (by hex global id).
     * NOTE(review): not called anywhere in this class, and it targets table {@code jfiretcc} while the
     * other queries use {@code fasttcc} — confirm which table is intended before using it.
     */
    private void deleteLogTableItem(DataSource dataSource, Collection<ArchiveWorker.TccTransactionInfo> deletes)
    {
        String traceId = TRACEID.currentTraceId();
        try (Connection connection = dataSource.getConnection();
             PreparedStatement preparedStatement = connection.prepareStatement("delete from jfiretcc where global_id=?"))
        {
            for (ArchiveWorker.TccTransactionInfo each : deletes)
            {
                String globalId = StringUtil.toHexString(each.getXid().getGlobalId());
                preparedStatement.setString(1, globalId);
                preparedStatement.addBatch();
                LOGGER.debug("traceId:{} 从日志表删除记录:{}", traceId, globalId);
            }
            preparedStatement.executeBatch();
        }
        catch (SQLException e)
        {
            // Log before rethrowing; ReflectUtil.throwException is expected to throw.
            LOGGER.error("traceId:{} 日志表删除数据出现未知异常", traceId, e);
            ReflectUtil.throwException(e);
        }
    }

    /**
     * Rebuilds a transaction together with its local transactions, TCC invokes and remote resources.
     */
    private TccTransaction recoverTccTransaction(DataSource dataSource, TccTransactionManager tccTransactionManager, ArchiveWorker.TccTransactionInfo tccTransactionInfo)
    {
        TccTransaction         tccTransaction    = recoverTccTransaction(tccTransactionManager, tccTransactionInfo);
        List<LocalTransaction> localTransactions = recoverLocalTransactions(tccTransactionManager, tccTransactionInfo.getLocalTransactionInfos(), dataSource, tccTransaction);
        recoverTccInvokes(tccTransactionManager, tccTransactionInfo.getTccInvokeInfos(), tccTransaction, localTransactions, dataSource);
        recoverRemoteResources(tccTransactionManager, tccTransactionInfo.getRemoteResourceInfos(), tccTransaction);
        return tccTransaction;
    }

    /**
     * Rebuilds the remote resources from their UTF-8 identifiers and installs them on the transaction.
     */
    private void recoverRemoteResources(TccTransactionManager tccTransactionManager, List<ArchiveWorker.RemoteResourceInfo> list, TccTransaction tccTransaction)
    {
        List<RemoteResource> remoteResources = new ArrayList<>();
        for (ArchiveWorker.RemoteResourceInfo remoteResourceInfo : list)
        {
            String identifier = new String(remoteResourceInfo.getIdentifierBytes(), charset);
            remoteResources.add(tccTransactionManager.reConstructRemoteResource(tccTransaction, identifier));
        }
        tccTransactionManager.resetRemoteResources(tccTransaction, remoteResources);
    }

    /**
     * Rebuilds the TCC invokes (deserialized parameters, resolved {@code @Tcc} method, associated
     * local transaction) and, for transactions already marked for commit or rollback, restores
     * whether each invoke's complete stage has finished.
     */
    private void recoverTccInvokes(TccTransactionManager tccTransactionManager, List<ArchiveWorker.TccInvokeInfo> list, TccTransaction tccTransaction, List<LocalTransaction> localTransactions, DataSource dataSource)
    {
        ByteBuf         buf        = HeapByteBuf.allocate(1024 * 1024);
        List<TccInvoke> tccInvokes = new ArrayList<>();
        for (ArchiveWorker.TccInvokeInfo tccInvokeInfo : list)
        {
            byte[] params                   = tccInvokeInfo.getParams();
            int    associativeTxIndex       = tccInvokeInfo.getAssociativeTxIndex();
            byte[] completeStageXidBranchId = tccInvokeInfo.getCompleteStageXidBranchId();
            byte[] signBytes                = tccInvokeInfo.getSignBytes();
            buf.clear().put(params);
            Object[]  paramArray = licp.deserialize(buf);
            Method    method     = findTccMethod(new String(signBytes, charset));
            TccInvoke tccInvoke  = tccTransactionManager.reConstructTccInvoke(tccTransaction, paramArray, method, localTransactions.get(associativeTxIndex), completeStageXidBranchId);
            tccInvokes.add(tccInvoke);
            int state = tccTransaction.getState();
            if (state != TccTransaction.ACTIVE)
            {
                if (state == TccTransaction.MARK_FOR_COMMIT)
                {
                    fetchCompleteStateFromDB(dataSource, tccInvoke);
                }
                else if (state == TccTransaction.MARK_FOR_ROLLBACK)
                {
                    // The local transaction was never committed, so there is nothing to cancel:
                    // the cancel branch counts as completed.
                    if (tccInvoke.getAssociatedLocalTransaction().isRollbacked())
                    {
                        tccInvoke.markCompleted();
                    }
                    else
                    {
                        fetchCompleteStateFromDB(dataSource, tccInvoke);
                    }
                }
            }
        }
        tccTransactionManager.setsetTccInvokes(tccTransaction, tccInvokes);
    }

    /**
     * Resolves a {@code className#methodName} signature to the first declared method of that name
     * annotated with {@code @Tcc}.
     *
     * @throws IllegalStateException if the signature has no '#'
     * @throws NullPointerException  if no matching method exists
     */
    private Method findTccMethod(String methodSign)
    {
        int separator = methodSign.indexOf('#');
        if (separator < 0)
        {
            throw new IllegalStateException("Malformed method signature in transaction log: " + methodSign);
        }
        String className  = methodSign.substring(0, separator);
        String methodName = methodSign.substring(separator + 1);
        try
        {
            for (Method each : Class.forName(className).getDeclaredMethods())
            {
                if (each.getName().equals(methodName) && each.isAnnotationPresent(Tcc.class))
                {
                    return each;
                }
            }
        }
        catch (ClassNotFoundException e)
        {
            ReflectUtil.throwException(e);
        }
        throw new NullPointerException("无法找到对应的方法，接口方法签名为：" + methodSign);
    }

    /**
     * Marks the invoke completed when a {@code fasttcc} row exists for its complete-stage xid.
     * Deliberately best effort: on a database error the invoke is simply left not completed.
     */
    private void fetchCompleteStateFromDB(DataSource dataSource, TccInvoke tccInvoke)
    {
        String traceId = TRACEID.currentTraceId();
        try (Connection connection = dataSource.getConnection();
             PreparedStatement preparedStatement = connection.prepareStatement("select * from fasttcc where global_id=? and branch_id=?"))
        {
            preparedStatement.setString(1, StringUtil.toHexString(tccInvoke.getCompleteStageXid().getGlobalId()));
            preparedStatement.setString(2, StringUtil.toHexString(tccInvoke.getCompleteStageXid().getBranchId()));
            try (ResultSet resultSet = preparedStatement.executeQuery())
            {
                if (resultSet.next())
                {
                    LOGGER.debug("traceId:{} tcc调用:{}从日志表查询，完成阶段已经执行完毕", traceId, tccInvoke.getCompleteStageXid());
                    tccInvoke.markCompleted();
                }
                else
                {
                    LOGGER.debug("traceId:{} tcc调用:{}从日志表查询，完成阶段尚未执行成功", traceId, tccInvoke.getCompleteStageXid());
                }
            }
        }
        catch (SQLException e)
        {
            LOGGER.warn("traceId:{} 恢复TCC调用的完成状态出现未知错误", traceId, e);
        }
    }

    /**
     * Rebuilds the local transactions, ordered by their {@code txIndex} regardless of the order of
     * {@code list}, restores their commit state and installs them on the transaction.
     *
     * @throws IndexOutOfBoundsException if a txIndex is outside {@code [0, list.size())}
     * @throws IllegalStateException     if two infos share a txIndex
     */
    private List<LocalTransaction> recoverLocalTransactions(TccTransactionManager tccTransactionManager, List<ArchiveWorker.LocalTransactionInfo> list, DataSource dataSource, TccTransaction tccTransaction)
    {
        // Callers look local transactions up by txIndex, so place each one in its own slot;
        // ArrayList.add(txIndex, ..) threw or mis-ordered when the infos were not sorted.
        LocalTransaction[] slots = new LocalTransaction[list.size()];
        for (ArchiveWorker.LocalTransactionInfo localTransactionInfo : list)
        {
            byte[]           branchId         = localTransactionInfo.getBranchId();
            int              preTxIndex       = localTransactionInfo.getPreTxIndex();
            int              txIndex          = localTransactionInfo.getTxIndex();
            LocalTransaction localTransaction = tccTransactionManager.reConstructLocalTransaction(tccTransaction, branchId, preTxIndex, txIndex);
            if (slots[txIndex] != null)
            {
                throw new IllegalStateException("Duplicate local transaction index in transaction log: " + txIndex);
            }
            slots[txIndex] = localTransaction;
        }
        List<LocalTransaction> localTransactions = new ArrayList<>(Arrays.asList(slots));
        recoverLocalTransactionState(dataSource, localTransactions);
        tccTransactionManager.resetLocalTransactions(tccTransaction, localTransactions);
        return localTransactions;
    }

    /**
     * Rebuilds the transaction shell (xid, role, state, propagator) from its archived info.
     */
    private TccTransaction recoverTccTransaction(TccTransactionManager tccTransactionManager, ArchiveWorker.TccTransactionInfo each)
    {
        XidImpl xidImpl = new XidImpl();
        xidImpl.setGlobalId(each.getXid().getGlobalId());
        Xid    xid                    = xidImpl;
        int    role                   = each.getRole();
        int    state                  = each.getState();
        byte[] propagatedByValueBytes = each.getPropagatedByValueBytes();
        // NOTE(review): defensive null guard; not visible here whether ArchiveWorker returns null or an
        // empty array when no propagator was logged.
        String propagatedBy = propagatedByValueBytes == null ? "" : new String(propagatedByValueBytes, charset);
        return tccTransactionManager.reConstruct(xid, role, state, propagatedBy);
    }

    /**
     * Marks each local transaction committed when a {@code fasttcc} row exists for its xid and
     * rolled back otherwise. Statement and result sets are closed explicitly so they do not leak
     * when the connection comes from a pool.
     */
    private void recoverLocalTransactionState(DataSource dataSource, List<LocalTransaction> localTransactions)
    {
        try (Connection connection = dataSource.getConnection();
             PreparedStatement preparedStatement = connection.prepareStatement("select * from fasttcc where global_id=? and branch_id=?"))
        {
            for (LocalTransaction localTransaction : localTransactions)
            {
                preparedStatement.setString(1, StringUtil.toHexString(localTransaction.xid().getGlobalId()));
                preparedStatement.setString(2, StringUtil.toHexString(localTransaction.xid().getBranchId()));
                try (ResultSet resultSet = preparedStatement.executeQuery())
                {
                    if (resultSet.next())
                    {
                        localTransaction.setCommited();
                    }
                    else
                    {
                        localTransaction.setRollbacked();
                    }
                }
            }
        }
        catch (SQLException e)
        {
            ReflectUtil.throwException(e);
        }
    }
}
